{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can run this notebook either locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "# If you're using Google Colab and not running locally, run this cell.\n",
    "\n",
    "## Install dependencies\n",
    "!pip install wget\n",
    "!apt-get install sox libsndfile1 ffmpeg\n",
    "!pip install text-unidecode\n",
    "!pip install ipython\n",
    "\n",
    "## Install NeMo\n",
    "BRANCH = 'main'\n",
    "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@{BRANCH}#egg=nemo_toolkit[asr]\n",
    "\n",
    "## Install TorchAudio\n",
    "!pip install torchaudio -f https://download.pytorch.org/whl/torch_stable.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# End-to-End Speaker Diarization in NeMo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"images/cascaded_diar_diagram.png\" alt=\"Cascaded Diar System \" style=\"width: 800px;\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Traditional cascaded (also referred to as modular or pipelined) speaker diarization systems consist of multiple modules, such as a speaker activity detection (SAD) module and a speaker embedding extractor module. Cascaded systems are often time-consuming to develop because each module must be trained and optimized separately."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"images/e2e_diar_diagram.png\" alt=\"End-to-end diarization model diagram\" style=\"width: 800px;\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "End-to-end speaker diarization systems, on the other hand, aim for a much simpler design in which a single neural network accepts raw audio signals and outputs speaker activity for each audio frame. End-to-end diarization models are therefore easier to optimize and deploy."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sortformer Diarization Inference\n",
    "\n",
    "As explained in the [Sortformer Diarization Training](https://github.com/NVIDIA/NeMo/blob/main/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb) tutorial, Sortformer is trained with Sort-Loss to generate speaker segments in arrival-time order. When a diarization model generates speaker segments in a predefined order, two permutation-matching problems disappear. First, we do not need to match permutations when training the diarization model together with multi-speaker automatic speech recognition (ASR) models. Second, we do not need to match permutations between inference windows when the model runs in streaming mode, where the audio is processed as a sequence of chunks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"images/intro_comparison.png\" alt=\"Intro Comparison\" style=\"width: 800px;\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A toy example for speaker diarization with a Sortformer model "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download a toy example audio file (`an4_diarize_test.wav`) and its ground-truth label file (`an4_diarize_test.rttm`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import wget\n",
    "ROOT = os.getcwd()\n",
    "data_dir = os.path.join(ROOT,'data')\n",
    "os.makedirs(data_dir, exist_ok=True)\n",
    "an4_audio = os.path.join(data_dir,'an4_diarize_test.wav')\n",
    "an4_rttm = os.path.join(data_dir,'an4_diarize_test.rttm')\n",
    "if not os.path.exists(an4_audio):\n",
    "    an4_audio_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.wav\"\n",
    "    an4_audio = wget.download(an4_audio_url, data_dir)\n",
    "if not os.path.exists(an4_rttm):\n",
    "    an4_rttm_url = \"https://nemo-public.s3.us-east-2.amazonaws.com/an4_diarize_test.rttm\"\n",
    "    an4_rttm = wget.download(an4_rttm_url, data_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot and listen to the audio. Pay attention to the two speakers in the audio clip."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import librosa\n",
    "\n",
    "# Load the audio, resampled to 16 kHz\n",
    "sr = 16000\n",
    "signal, sr = librosa.load(an4_audio, sr=sr)\n",
    "\n",
    "# Plot the waveform against time in seconds\n",
    "fig, ax = plt.subplots(1, 1)\n",
    "fig.set_figwidth(20)\n",
    "fig.set_figheight(2)\n",
    "ax.plot(np.arange(len(signal)) / sr, signal, 'gray')\n",
    "fig.suptitle('Reference merged an4 audio', fontsize=16)\n",
    "ax.set_xlabel('time (secs)', fontsize=18)\n",
    "ax.margins(x=0)\n",
    "ax.set_ylabel('signal strength', fontsize=16)\n",
    "\n",
    "IPython.display.Audio(an4_audio)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have downloaded the example audio file, which contains multiple speakers, we can feed the audio clip into the Sortformer diarizer and get the speaker diarization results."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download Sortformer diarizer model\n",
    "\n",
    "To download the Sortformer diarizer from the [Sortformer HuggingFace model card](https://huggingface.co/nvidia/diar_sortformer_4spk-v1), you need a [HuggingFace access token](https://huggingface.co/docs/hub/en/security-tokens), which you can register in your Python environment using the [HuggingFace CLI](https://huggingface.co/docs/huggingface_hub/main/en/guides/cli).\n",
    "\n",
    "If you have trouble getting a HuggingFace token, you can instead download the Sortformer model file from the [Sortformer HuggingFace model card](https://huggingface.co/nvidia/diar_sortformer_4spk-v1) and specify the path to the downloaded model.\n",
    "\n",
    "### Timestamp output and tensor output\n",
    "\n",
    "When you call the `diarize()` method with `include_tensor_outputs=True`, the diarization model returns both the predicted speaker-labeled segments and tensors containing T by N sigmoid values, where T is the number of audio frames and N is the maximum number of speakers.\n",
    "\n",
    "If `include_tensor_outputs` is omitted or set to `False`, only the speaker-labeled segments are returned."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.models import SortformerEncLabelModel\n",
    "from huggingface_hub import get_token as get_hf_token\n",
    "import torch\n",
    "\n",
    "hf_token = get_hf_token()\n",
    "if hf_token is not None and hf_token.startswith(\"hf_\"):\n",
    "    # If you have logged in to the HuggingFace Hub and have an access token\n",
    "    diar_model = SortformerEncLabelModel.from_pretrained(\"nvidia/diar_sortformer_4spk-v1\")\n",
    "else:\n",
    "    # Otherwise, download the \".nemo\" file from https://huggingface.co/nvidia/diar_sortformer_4spk-v1 and specify its path.\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    diar_model = SortformerEncLabelModel.restore_from(restore_path=\"/path/to/diar_sortformer_4spk-v1.nemo\", map_location=device, strict=False)\n",
    "\n",
    "pred_list, pred_tensor_list = diar_model.diarize(audio=an4_audio, include_tensor_outputs=True)\n",
    "\n",
    "print(pred_list[0])\n",
    "print(pred_tensor_list[0].shape)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Now let's visualize the predicted speaker diarization results. The plot below transposes the output tensor so that each row represents a speaker and each column represents a frame. Each sigmoid value is the predicted probability that the corresponding speaker is active in that frame.\n",
    "\n",
    "In the following code, we'll plot the predicted speaker diarization results for the sample audio file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "def plot_diarout(preds):\n",
    "    preds_mat = preds.cpu().numpy().transpose()\n",
    "    cmap_str, grid_color_p= 'viridis', 'gray'\n",
    "    LW, FS = 0.4, 36\n",
    "\n",
    "    yticklabels = [\"spk0\", \"spk1\", \"spk2\", \"spk3\"]\n",
    "    yticks = np.arange(len(yticklabels))\n",
    "    fig, axs = plt.subplots(1, 1, figsize=(30, 3)) \n",
    "\n",
    "    axs.imshow(preds_mat, cmap=cmap_str, interpolation='nearest') \n",
    "    axs.set_title('Predictions', fontsize=FS)\n",
    "    axs.set_xticks(np.arange(-.5, preds_mat.shape[1], 1), minor=True)\n",
    "    axs.set_yticks(yticks)\n",
    "    axs.set_yticklabels(yticklabels)\n",
    "    axs.set_xlabel(\"80 ms Frames\", fontsize=FS)\n",
    "    axs.grid(which='minor', color=grid_color_p, linestyle='-', linewidth=LW)\n",
    "\n",
    "    plt.savefig('plot.png', dpi=300) \n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_diarout(pred_tensor_list[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the first speaker always gets the first generic speaker label `spk0`. The Sortformer model is trained to generate speaker labels in arrival-time order, so the speaker permutation is always predictable once we know the order in which the speakers first appear."
   ]
  },
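  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the arrival-time convention concrete, the sketch below reorders the speaker columns of a toy T by N activity matrix by the first frame in which each speaker is active. It only illustrates the ordering convention and is not how Sortformer works internally (the model learns this ordering through Sort-Loss). The function name `sort_by_arrival`, the 0.5 threshold, and the toy values are assumptions made for this example.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def sort_by_arrival(probs, threshold=0.5):\n",
    "    # probs: (T, N) array of per-frame speaker probabilities (toy data here).\n",
    "    active = probs > threshold\n",
    "    num_frames = probs.shape[0]\n",
    "    # First active frame per speaker; speakers that never speak are placed last.\n",
    "    first_frame = np.where(active.any(axis=0), active.argmax(axis=0), num_frames)\n",
    "    order = np.argsort(first_frame, kind='stable')\n",
    "    return probs[:, order], order\n",
    "\n",
    "toy = np.array([\n",
    "    [0.1, 0.9, 0.0],\n",
    "    [0.2, 0.8, 0.0],\n",
    "    [0.9, 0.1, 0.0],\n",
    "    [0.8, 0.1, 0.0],\n",
    "])\n",
    "sorted_probs, order = sort_by_arrival(toy)\n",
    "print(order)\n",
    "```\n",
    "\n",
    "In this toy matrix the second column speaks first, so it becomes the first column after sorting, and the silent third column stays last."
   ]
  },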
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Post-Processing of Speaker Timestamps\n",
    "\n",
    "In the previous steps, we obtained speaker-activity predictions in the form of tensors. Although the floating-point values can be interpreted as speaker probabilities, they are not designed to be consumed as is. They still need to be converted into segments, each with a start time and an end time.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from omegaconf import OmegaConf\n",
    "import os\n",
    "import wget\n",
    "import pandas as pd\n",
    "from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object\n",
    "from pyannote.metrics.diarization import DiarizationErrorRate\n",
    "\n",
    "ROOT = os.getcwd()\n",
    "data_dir = os.path.join(ROOT,'data')\n",
    "\n",
    "yaml_name=\"sortformer_diar_4spk-v1_dihard3-dev.yaml\"\n",
    "MODEL_CONFIG = os.path.join(data_dir, yaml_name)\n",
    "if not os.path.exists(MODEL_CONFIG):\n",
    "    config_url = f\"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/post_processing/{yaml_name}\"\n",
    "    MODEL_CONFIG = wget.download(config_url, data_dir)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The post-processing YAML file `sortformer_diar_4spk-v1_dihard3-dev.yaml` contains parameters that are optimized for the lowest Diarization Error Rate (DER) on the [DIHARD III development set](https://catalog.ldc.upenn.edu/LDC2022S12)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nemo.collections.asr.parts.utils.vad_utils import load_postprocessing_from_yaml\n",
    "import json\n",
    "from omegaconf import OmegaConf \n",
    "post_processing_params = load_postprocessing_from_yaml(MODEL_CONFIG)\n",
    "print(json.dumps(OmegaConf.to_container(post_processing_params), indent=4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Post-processing Parameters\n",
    "\n",
    "The parameters in the post-processing YAML configuration control the following:\n",
    "\n",
    "| **Parameter**      | **Description**                                               |\n",
    "|-------------------:|--------------------------------------------------------------|\n",
    "| **onset**          | Onset threshold for detecting the beginning of a speech segment. |\n",
    "| **offset**         | Offset threshold for detecting the end of a speech segment.      |\n",
    "| **pad_onset**      | Adds the specified duration at the beginning of each speech segment. |\n",
    "| **pad_offset**     | Adds the specified duration at the end of each speech segment.     |\n",
    "| **min_duration_on**| Removes short speech segments if the duration is less than the specified minimum duration. |\n",
    "| **min_duration_off**| Removes short silences if the duration is less than the specified minimum duration. |\n",
    "\n",
    "Now let's check the diarization output timestamps and compare how post-processing changes the diarization output."
   ]
  },
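  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below illustrates how these six parameters could act on a single speaker's frame-level probabilities. It is a simplified re-implementation for illustration only, not NeMo's post-processing code. The function name `probs_to_segments`, the default values, the order of the steps, and the 0.08 s frame length (taken from the 80 ms frames in the plot above) are assumptions.\n",
    "\n",
    "```python\n",
    "def probs_to_segments(probs, frame_len=0.08, onset=0.5, offset=0.5,\n",
    "                      pad_onset=0.0, pad_offset=0.0,\n",
    "                      min_duration_on=0.0, min_duration_off=0.0):\n",
    "    # 1) Thresholding: open a segment when prob >= onset,\n",
    "    #    close it when prob drops below offset.\n",
    "    segments, start = [], None\n",
    "    for i, p in enumerate(probs):\n",
    "        if start is None and p >= onset:\n",
    "            start = i * frame_len\n",
    "        elif start is not None and p < offset:\n",
    "            segments.append([start, i * frame_len])\n",
    "            start = None\n",
    "    if start is not None:\n",
    "        segments.append([start, len(probs) * frame_len])\n",
    "    # 2) Padding: extend each segment at its start and end.\n",
    "    segments = [[max(0.0, s - pad_onset), e + pad_offset] for s, e in segments]\n",
    "    # 3) min_duration_off: merge segments separated by a short silence.\n",
    "    merged = []\n",
    "    for s, e in segments:\n",
    "        if merged and s - merged[-1][1] < min_duration_off:\n",
    "            merged[-1][1] = max(merged[-1][1], e)\n",
    "        else:\n",
    "            merged.append([s, e])\n",
    "    # 4) min_duration_on: drop segments that are too short.\n",
    "    return [(round(s, 3), round(e, 3)) for s, e in merged if e - s >= min_duration_on]\n",
    "\n",
    "probs = [0.1, 0.7, 0.8, 0.45, 0.6, 0.2, 0.1, 0.1, 0.9, 0.1]\n",
    "print(probs_to_segments(probs))\n",
    "print(probs_to_segments(probs, onset=0.6, offset=0.4, min_duration_on=0.1))\n",
    "```\n",
    "\n",
    "With a single 0.5 threshold, the toy sequence yields three segments. With a higher onset and a lower offset, the brief dip to 0.45 no longer closes the first segment, and the one-frame segment at the end is removed by `min_duration_on`."
   ]
  },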
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def show_diar_df(pred_session_list):\n",
    "    data = [segment.split() for segment in pred_session_list]\n",
    "    df = pd.DataFrame(data, columns=['Start Time', 'End Time', 'Speaker'])\n",
    "    print(df)\n",
    "\n",
    "# Call the `diarize` method without post-processing parameters\n",
    "pred_list_bn = diar_model.diarize(audio=an4_audio)\n",
    "print(\"  [Default Binarized Diarization Output]:\")\n",
    "show_diar_df(pred_list_bn[0])\n",
    "\n",
    "# Call the `diarize` method with post-processing parameters\n",
    "pred_list_pp = diar_model.diarize(audio=an4_audio, postprocessing_yaml=MODEL_CONFIG)\n",
    "print(\"  [Post-processed Diarization Output]:\")\n",
    "show_diar_df(pred_list_pp[0])\n",
    "\n",
    "print(\"  [Ground-truth Diarization Output]:\")\n",
    "ref_labels = rttm_to_labels(an4_rttm)\n",
    "show_diar_df(ref_labels)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see that the post-processed segments have more onset padding, although the differences are small.\n",
    "\n",
    "Now let's calculate DER (Diarization Error Rate) between the predicted labels and the ground-truth labels for both raw binarized and post-processed diarization outputs."
   ]
  },
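  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of what the metric measures, DER is the sum of missed speech, false-alarm speech, and speaker-confusion time, divided by the total reference speech time. The sketch below computes a simplified frame-level version on toy labels with at most one speaker per frame. It assumes the hypothesis speaker labels are already mapped to the reference labels, whereas `pyannote.metrics` finds the optimal mapping for you. The function name `frame_level_der` and the toy labels are hypothetical.\n",
    "\n",
    "```python\n",
    "def frame_level_der(reference, hypothesis):\n",
    "    # Each list holds one label per frame; None marks non-speech.\n",
    "    assert len(reference) == len(hypothesis)\n",
    "    pairs = list(zip(reference, hypothesis))\n",
    "    # Reference speech that the hypothesis labels as non-speech.\n",
    "    miss = sum(r is not None and h is None for r, h in pairs)\n",
    "    # Hypothesis speech where the reference has non-speech.\n",
    "    false_alarm = sum(r is None and h is not None for r, h in pairs)\n",
    "    # Both contain speech but the speaker labels differ.\n",
    "    confusion = sum(r is not None and h is not None and r != h for r, h in pairs)\n",
    "    total = sum(r is not None for r in reference)\n",
    "    return (miss + false_alarm + confusion) / total\n",
    "\n",
    "ref = ['A', 'A', 'A', None, 'B', 'B', 'B', 'B', None, None]\n",
    "hyp = ['A', 'A', None, None, 'B', 'B', 'A', 'B', 'B', None]\n",
    "print(f'{frame_level_der(ref, hyp):.3f}')\n",
    "```\n",
    "\n",
    "The toy labels contain one missed frame, one false-alarm frame, and one confused frame out of seven reference speech frames, so the result is 3/7."
   ]
  },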
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the reference labels from the ground-truth RTTM file\n",
    "ref_labels = rttm_to_labels(an4_rttm)\n",
    "\n",
    "reference = labels_to_pyannote_object(ref_labels, uniq_name=\"binarize\")\n",
    "hypothesis1 = labels_to_pyannote_object(pred_list_bn[0], uniq_name=\"binarize\")\n",
    "der_metric1 = DiarizationErrorRate(collar=0, skip_overlap=False)\n",
    "der_metric1(reference, hypothesis1, detailed=True)\n",
    "print(f\"Raw Binarization DER: {abs(der_metric1):.6f}\")\n",
    "\n",
    "reference = labels_to_pyannote_object(ref_labels, uniq_name=\"post_processing\")\n",
    "hypothesis2 = labels_to_pyannote_object(pred_list_pp[0], uniq_name=\"post_processing\")\n",
    "der_metric2 = DiarizationErrorRate(collar=0, skip_overlap=False)\n",
    "der_metric2(reference, hypothesis2, detailed=True)\n",
    "print(f\"Post-Processing DER:  {abs(der_metric2):.6f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the post-processing parameters are tuned to minimize DER for the model's sigmoid tensor outputs, the post-processed output generally yields a lower DER than the raw binarized output. To achieve the lowest DER, use post-processing parameters that are optimized for your dataset of interest and your diarization model."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "metadata": {
     "collapsed": false
    },
    "source": []
   }
  },
  "vscode": {
   "interpreter": {
    "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
